# coding: utf-8
# -------------------------------------------------------------------
# aaPanel
# -------------------------------------------------------------------
# Copyright (c) 2015-2099 aaPanel(www.aapanel.com) All rights reserved.
# -------------------------------------------------------------------
# Author: hwliang <hwl@aapanel.com>
# -------------------------------------------------------------------

# ------------------------------
# HTTP proxy module
# ------------------------------

import os
import re
import socket
import time
from http.cookies import SimpleCookie

import requests
import urllib3.util.connection as urllib3_conn

from BTPanel import request, Response, public, app, get_phpmyadmin_dir, session


class HttpProxy:
    _pma_path = None

    @staticmethod
    def _err_resp(msg: str = None):
        """
            @name Build a plain-text HTTP 500 error response
            @param msg<str> response body; a generic socket-error text is used
                            when it is None or empty
            @return Response
        """
        return Response(
            msg or "something wrong with socket, please check and try again...", 500
        )

    def get_res_headers(self, p_res):
        """
            @name 获取响应头
            @author hwliang<2022-01-19>
            @param p_res<Response> requests响应对像
            @return dict
        """
        headers = {}
        for h in p_res.headers.keys():
            if h in ['content-encoding', 'Content-Encoding', 'transfer-encoding', 'Transfer-Encoding']:
                continue
            headers[h] = p_res.headers[h]
            if h in ['location', 'Location']:

                # ============  redirect ===================
                # phpmyadmin
                if headers[h].find('phpmyadmin_') != -1:
                    if not self._pma_path:
                        self._pma_path = get_phpmyadmin_dir()
                        if self._pma_path:
                            self._pma_path = self._pma_path[0]
                        else:
                            self._pma_path = ''
                    headers[h] = headers.get(h, "").replace(self._pma_path, 'phpmyadmin')
                # adminer
                elif headers[h].find("adminer_") != -1:
                    from adminer.manager import AdminerManager
                    adminer_dir, _ = AdminerManager().adminer_dir_port
                    headers[h] = headers.get(h, "").replace(adminer_dir, 'adminer')
                # ============  redirect end ==================

                if headers[h].find('127.0.0.1') != -1:
                    headers[h] = re.sub(r"https?://127.0.0.1(:\d+)?/", request.url_root, headers[h])
                if request.url_root.find('https://') == 0:
                    headers[h] = headers.get(h, '').replace('http://', 'https://')
        return headers

    def set_res_headers(self, res, p_res):
        """
            @name Post-process the outgoing response headers
            @author hwliang<2022-01-19>
            @param res<Response> flask response object
            @param p_res<Response> upstream response object (requests); currently unused
            @return res<Response> the same object that was passed in, unchanged
        """
        # Disabled: this used to copy the upstream cookies onto the outgoing
        # response. Kept for reference; the method is currently a no-op.
        # from datetime import datetime
        # cookie_dict = p_res.cookies.get_dict()
        # expires = datetime.utcnow() + app.permanent_session_lifetime
        # for k in cookie_dict.keys():
        #     httponly = True
        #     if k in ['phpMyAdmin']: httponly = True
        #     res.set_cookie(k, cookie_dict[k],
        #                         expires=expires, httponly=httponly,
        #                         path='/')

        return res

    def get_pma_phpversion(self):
        """
            @name Get the PHP version phpMyAdmin is currently configured with
            @author hwliang<2022-01-19>
            @return the 'phpversion' entry of the phpMyAdmin status reported by
                    panelPlugin, or None when that entry is absent
        """
        from panelPlugin import panelPlugin

        status = panelPlugin().getPHPMyAdminStatus()
        return status['phpversion'] if 'phpversion' in status else None

    def get_pma_version(self):
        """
            @name Get the installed phpMyAdmin version
            @author hwliang<2022-01-19>
            @return str stripped content of <setup_path>/phpmyadmin/version.pl,
                        or '' when the file is missing or empty
        """
        version_file = public.get_setup_path() + '/phpmyadmin/version.pl'
        if os.path.exists(version_file):
            # A whitespace-only version file is reported as "unknown" ('').
            return public.readFile(version_file).strip() or ''
        return ''

    def set_pma_phpversion(self):
        """
            @name Make sure phpMyAdmin runs on a PHP version it is compatible with
            @author hwliang<2022-01-19>
            @return True  when the current PHP version is already compatible
                    False when the phpMyAdmin version or current PHP version is
                          unknown, the phpMyAdmin version is not in the table,
                          or no compatible PHP build is installed
                    None  after a switch to a compatible version was requested
        """
        # phpMyAdmin version -> PHP versions it may run on, in ascending order.
        compatible = {
            '4.0': ['52', '53', '54'],
            '4.4': ['54', '55', '56'],
            '4.9': ['55', '56', '70', '71', '72', '73', '74'],
            '5.0': ['70', '71', '72', '73', '74'],
            '5.1': ['71', '72', '73', '74', '80'],
            '5.2': ['72', '73', '74', '80', '81'],
            '5.3': ['72', '73', '74', '80', '81'],
        }

        pma_version = self.get_pma_version()
        if not pma_version:
            return False

        current_php = self.get_pma_phpversion()
        if not current_php:
            return False

        candidates = compatible.get(pma_version)
        if candidates is None:
            return False
        if current_php in candidates:
            return True

        # Keep only the candidates that actually have a PHP binary installed.
        php_root = '/www/server/php'
        available = [
            ver for ver in candidates
            if os.path.exists(php_root + '/' + ver + '/bin/php')
        ]
        if not available:
            return False

        # The list is ascending, so the last entry is the newest installed one.
        target = available[-1]

        import ajax
        args = public.dict_obj()
        args.phpversion = target
        ajax.ajax().setPHPMyAdmin(args)
        public.WriteLog(
            'Database',
            'The PHP version used by phpMyAdmin has been detected to be incompatible and has been automatically changed to the best compatible version: PHP-' + target
        )
        time.sleep(0.5)

    def get_request_headers(self):
        """
            @name Collect the headers of the current request for forwarding
            @author hwliang<2022-01-19>
            @return dict every request header; in the Cookie header the
                         cookies named below are removed
        """
        # Cookies stripped from the forwarded Cookie header.
        stripped_cookies = (
            app.config['SESSION_COOKIE_NAME'], 'bt_user_info', 'file_recycle_status', 'ltd_end',
            'memSize', 'page_number', 'pro_end', 'request_token', 'serverType', 'site_model',
            'sites_path', 'soft_remarks', 'load_page', 'Path', 'distribution', 'order',
        )

        forwarded = {}
        for name in request.headers.keys():
            value = request.headers.get(name)
            if name == 'Cookie':
                # noinspection PyUnresolvedReferences
                jar = SimpleCookie(value)
                for cookie_name in stripped_cookies:
                    jar.pop(cookie_name, None)
                # Re-serialise the remaining cookies as "a=1; b=2".
                value = jar.output(header='', sep=';').strip()
            forwarded[name] = value
        return forwarded

    def form_to_dict(self, form):
        """
            @name 将表单转为字典
            @author hwliang<2022-02-18>
            @param form<request.form> 表单数据
            @return dict
        """

        data = {}
        for k in form.keys():
            data[k] = form.getlist(k)
            if len(data[k]) == 1: data[k] = data[k][0]
        return data

    def proxy(self, proxy_url: str, allow_redirects: bool = False):
        """
            @name Proxy the current request to the given URL
            @author hwliang<2022-01-19>
            @param proxy_url<string> URL of the upstream to forward the request to
            @param allow_redirects<bool> passed through to requests; whether it
                                         follows upstream redirects itself
            @return Response the relayed upstream response; a plain-text notice
                             when the phpMyAdmin PHP version was switched; a 500
                             response for unsupported methods or on any exception
        """
        # Function-scope import: only the upload branch below needs it.
        import tempfile

        try:
            # Process-wide urllib3 patch: resolve hosts over IPv4 only.
            urllib3_conn.allowed_gai_family = lambda: socket.AF_INET
            s_key = 'proxy_{}_{}'.format(app.secret_key, self.get_pma_version())

            if s_key not in session:
                # First proxied request for this panel session: create and
                # configure the upstream HTTP session that later calls reuse.
                session[s_key] = requests.Session()
                session[s_key].keep_alive = False
                session[s_key].headers = {
                    'User-Agent': 'BT-Panel',
                    'Connection': 'close'
                }

                if proxy_url.find('phpmyadmin') != -1:
                    if proxy_url.find('https://') == 0:
                        session[s_key].cookies.update({'pma_lang_https': 'zh_CN'})
                    else:
                        session[s_key].cookies.update({'pma_lang': 'zh_CN'})
                    self.set_pma_phpversion()

            upstream = session[s_key]

            if 'Authorization' in request.headers:
                upstream.headers['Authorization'] = request.headers['Authorization']

            # Best effort: present the host the client used to the upstream.
            try:
                upstream.headers['Host'] = public.en_punycode(
                    request.url_root
                ).replace('http://', '').replace('https://', '').split('/')[0]
            except Exception:
                pass

            # Per-request headers are deliberately not forwarded; only the
            # headers stored on the upstream session are sent.
            common = dict(headers=None, verify=False, allow_redirects=allow_redirects)

            if request.method == 'GET':
                # Forward GET request
                p_res = upstream.get(proxy_url, **common)
            elif request.method == 'POST':
                # Forward POST request
                if request.files:  # file upload
                    tmp_path = '{}/tmp'.format(public.get_panel_path())
                    os.makedirs(tmp_path, 0o600, exist_ok=True)

                    files = {}
                    spooled = []
                    try:
                        for key in request.files:
                            upload_files = request.files.getlist(key)
                            filename = upload_files[0].filename
                            if not filename: filename = public.GetRandomString(12)

                            # Spool into an anonymous temp file. The client-supplied
                            # filename is only forwarded as the multipart filename and
                            # is never used as a local path, so a name such as
                            # "../../x" cannot overwrite or delete files on disk.
                            tmp_f = tempfile.TemporaryFile(dir=tmp_path)
                            spooled.append(tmp_f)
                            # All parts uploaded under one field are concatenated.
                            for part in upload_files:
                                tmp_f.write(part.read())
                            tmp_f.seek(0)
                            files[key] = (filename, tmp_f)

                        # Forward the upload request
                        p_res = upstream.post(
                            proxy_url,
                            data=self.form_to_dict(request.form),
                            files=files,
                            **common
                        )
                    finally:
                        # Closing removes the temp files, even when the POST failed.
                        for tmp_f in spooled:
                            tmp_f.close()
                else:
                    p_res = upstream.post(
                        proxy_url,
                        data=self.form_to_dict(request.form),
                        **common
                    )
            else:
                return Response('不支持的请求类型', 500)

            # Automatic PHP version switching for phpMyAdmin
            if proxy_url.find('phpmyadmin') != -1 and proxy_url.find('/index.php') != -1:
                if len(p_res.content) < 1024:
                    # A short page containing one of these markers is treated as a
                    # PHP incompatibility error page.
                    if p_res.content.find(b'syntax error, unexpected') != -1 or p_res.content.find(
                            b'offset access syntax with') != -1 or p_res.content.find(b'+ is required') != -1:
                        self.set_pma_phpversion()
                        return 'Incompatible PHP version, an attempt has been made to automatically switch to a compatible PHP version, please refresh the page and try again!'
                elif p_res.content.find(b'<strong>Deprecation Notice</strong>') != -1 and not session.get(
                        'set_pma_phpversion'):
                    # Only attempt this switch once per panel session.
                    self.set_pma_phpversion()
                    session['set_pma_phpversion'] = True
                    return 'Incompatible PHP version, an attempt has been made to automatically switch to a compatible PHP version, please refresh the page and try again!'

            res = Response(
                p_res.content,
                headers=self.get_res_headers(p_res),
                content_type=p_res.headers.get('content-type', None),
                status=p_res.status_code
            )
            res = self.set_res_headers(res, p_res)
            return res
        except Exception as ex:
            # Hide the randomised install directory names in the error text.
            err_msg = re.sub(r"adminer_\S+", "adminer_...", str(ex))
            err_msg = re.sub(r"phpmyadmin_\S+", "phpmyadmin_...", err_msg)
            return Response(err_msg, 500)

    # TODO: not yet complete
    def proxy_socket(self, proxy_url: str, allow_redirects: bool = False):
        """
            @name Proxy the current request to a unix domain socket
            @param proxy_url<string> http+unix://<socket_path>/<request_uri>;
                                     the socket path must contain ".sock"
            @param allow_redirects<bool> passed through to requests; whether it
                                         follows upstream redirects itself
            @return Response the relayed upstream response, or a 500 response
                             describing the failure
        """
        try:
            if not proxy_url.startswith("http+unix://"):
                return self._err_resp(
                    "Socket proxy error: proxy_url format error. It should start with 'http+unix://'"
                )

            from urllib.parse import urlparse, urlunparse, quote
            try:
                from requests_unixsocket import Session as ux_Session
            except ImportError:
                # Try to install the missing dependency once, then retry the import.
                public.ExecShell("btpip install requests_unixsocket")
                try:
                    # noinspection PyUnresolvedReferences
                    from requests_unixsocket import Session as ux_Session
                except Exception:
                    return self._err_resp("The 'requests_unixsocket' module is not installed")

            # Split "<socket_path>.sock<request_uri>" at the first ".sock".
            parsed_url = urlparse(proxy_url)
            full_path = parsed_url.netloc + parsed_url.path
            if not full_path.startswith("/"):
                full_path = "/" + full_path

            socket_ext = ".sock"
            socket_pos = full_path.find(socket_ext)
            if socket_pos == -1:
                return self._err_resp("Socket proxy error: Invalid socket proxy URL format.")

            socket_path_end = socket_pos + len(socket_ext)
            socket_path = full_path[:socket_path_end]
            request_uri = full_path[socket_path_end:] or "/"

            # The socket path is carried percent-encoded in the host part of the URL.
            # format: http+unix://<socket_path>/<request_uri>
            proxy_url = urlunparse((
                parsed_url.scheme,
                quote(socket_path, safe=""),
                request_uri,
                parsed_url.params,
                parsed_url.query,
                parsed_url.fragment
            ))

            req_headers = self.get_request_headers()
            # The context manager closes the session's connections once the
            # response body has been fully read.
            with ux_Session() as sess:
                if request.method == "GET":
                    p_res = sess.get(
                        proxy_url, headers=req_headers, timeout=10, allow_redirects=allow_redirects
                    )
                elif request.method == "POST":
                    data = self.form_to_dict(request.form)
                    files = None
                    if request.files:
                        files = {}
                        for key in request.files:
                            # Only the first file of each field is forwarded.
                            fs = request.files.getlist(key)[0]
                            files[key] = (fs.filename, fs.stream)
                        # requests re-encodes the multipart body with a new boundary
                        # and only sets Content-Type when none is given, so the
                        # client's Content-Type (old boundary) must not be forwarded.
                        req_headers = {
                            k: v for k, v in req_headers.items() if k.lower() != 'content-type'
                        }

                    p_res = sess.post(
                        proxy_url, data=data, files=files, headers=req_headers, timeout=10,
                        allow_redirects=allow_redirects
                    )
                else:
                    return self._err_resp(f"Unsupported method: {request.method}")

            # Relay the upstream *response* headers, not the forwarded request
            # headers (which would echo Cookie/Host/etc. back to the client).
            return Response(
                p_res.content,
                status=p_res.status_code,
                headers=self.get_res_headers(p_res),
                content_type=p_res.headers.get("content-type", None)
            )

        except Exception as ex:
            return self._err_resp(f"Socket proxy error: {str(ex)}")
